Fix incorrect reduction condition in fused_scatter_reduce Triton kernel#638
Fix incorrect reduction condition in fused_scatter_reduce Triton kernel#638Umang-projects wants to merge 4 commits intopyg-team:masterfrom
Conversation
Codecov Report✅ All modified and coverable lines are covered by tests. ❌ Your project check has failed because the head coverage (58.40%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## master #638 +/- ##
===========================================
- Coverage 70.13% 58.40% -11.74%
===========================================
Files 37 72 +35
Lines 1671 3236 +1565
Branches 0 262 +262
===========================================
+ Hits 1172 1890 +718
- Misses 499 1344 +845
- Partials 0 2 +2 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
for more information, see https://pre-commit.ci
|
Hi @rusty1s , (Note: The Codecov check is failing due to overall target coverage, but the modified lines are fully covered). Let me know if i need to make any changes. |
Description
While reviewing the Triton implementations for
fused_scatter_reduce, I noticed a copy-paste error in theREDUCE1block of_fused_scatter_reduce_forward_kernel.Specifically, when
REDUCE1is evaluated forminormax, the kernel incorrectly checksREDUCE2 == 3andREDUCE3 == 4instead ofREDUCE1. This causes silent mathematical errors and failing reductions when multiple reduction types (including min/max) are passed in specific orders.Changes Made:
if REDUCE1 > 0:to correctly checkREDUCE1 == 3(min) andREDUCE1 == 4(max).Note: I also noticed the TODOs regarding double computation of
sum/mean, unrolling the loops (sincetl.static_unrollis now available), and adding backward passes. I plan to address those optimizations in follow-up PRs once this critical fix is merged.